// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/tensor.h"

namespace cinn {
namespace frontend {

struct ComputationContext;

// Convenience facade around CINN compilation: builds/compiles a Program and
// exposes tensor access plus execution of the compiled artifact. The heavy
// state lives in the opaque ComputationContext (defined in the .cc).
class CinnComputation {
 public:
  // Compilation knobs layered on top of the framework-level
  // CompilationContext base. The in-class initializers below ARE the
  // defaults; DefaultCompileOptions() only needs to touch base-class fields.
  struct CompileOptions : public hlir::framework::CompilationContext {
    bool use_decomposer = false;       // run the decomposer pass pipeline
    bool do_prerun = true;             // warm-up run right after compilation
    bool use_default_passes = true;    // apply the standard program passes
    std::vector<std::string> passes;   // extra user-specified passes
  };

  inline static CompileOptions DefaultCompileOptions() {
    CompileOptions options;
    // The derived members already carry their defaults via in-class
    // initializers; only the inherited field needs an explicit value.
    options.with_instantiate_variables = true;
    return options;
  }

  /**
   * build a program from a builder (NetBuilder or CINNBuilder), then compile
   * it.
   * @param target the target to run the program
   * @param builder program builder (NetBuilder or CINNBuilder)
   * @param options CompileOptions, config the compilation steps
   * @param outputs program output variables, if outputs is empty, then the
   * output variable of the last instruction of the program is used
   * @param stream CUDA stream, the value is meaningful only when target is
   * NVGPU
   * @return shared_ptr pointing to CinnComputation instance
   */
  static std::shared_ptr<CinnComputation> BuildAndCompile(
      const Target &target,
      NetBuilder &builder,  // NOLINT
      const CompileOptions &options = DefaultCompileOptions(),
      const std::vector<Variable> &outputs = {},
      void *stream = nullptr);
  /**
   * compile the program
   * @param target the target to run the program
   * @param program program (usually generated by a Builder, or converted from
   * Paddle model)
   * @param options CompileOptions, config the compilation steps
   * @param outputs program output variables, if outputs is empty, then the
   * output variable of the last instruction of the program is used
   * @param stream CUDA stream, the value is meaningful only when target is
   * NVGPU
   * @return shared_ptr pointing to CinnComputation instance
   */
  static std::shared_ptr<CinnComputation> Compile(
      const Target &target,
      Program &program,  // NOLINT
      const CompileOptions &options = DefaultCompileOptions(),
      const std::vector<Variable> &outputs = {},
      void *stream = nullptr);
  /**
   * convert a paddle model to program, then compile it.
   * @param target the target to run the program
   * @param model_path the path of the paddle model
   * @param input_names input variable names of paddle model
   * @param input_shapes input variable shapes of paddle model
   * @param params_combined whether params are stored combined
   * @param options CompileOptions, config the compilation steps
   * @param stream CUDA stream, the value is meaningful only when target is
   * NVGPU
   * @return shared_ptr pointing to CinnComputation instance
   */
  static std::shared_ptr<CinnComputation> CompilePaddleModel(
      const Target &target,
      const std::string &model_path,
      const std::vector<std::string> &input_names,
      const std::vector<hlir::framework::shape_t> &input_shapes,
      bool params_combined,
      const CompileOptions &options = DefaultCompileOptions(),
      void *stream = nullptr);

  /**
   * get all variable names in the program
   */
  std::vector<std::string> GetAllTensorNames();

  /**
   * get tensor by name
   * @param name tensor name
   */
  hlir::framework::Tensor GetTensor(const std::string &name);

  /**
   * get input tensors
   */
  std::vector<hlir::framework::Tensor> GetInputTensors();

  /**
   * get output tensors
   */
  std::vector<hlir::framework::Tensor> GetOutputTensors();

  /**
   * set the data of a tensor from user specified buffer.
   * if tensor is in NVGPU device memory, cudaMemcpy is used.
   * @param t the tensor
   * @param data address of the memory buffer to store tensor's data
   * @param size size of the memory buffer
   */
  void SetTensorData(hlir::framework::Tensor &t,  // NOLINT
                     void *data,
                     size_t size);

  /**
   * set the data of a tensor (specified by it's name) from user specified
   * buffer. if tensor is in NVGPU device memory, cudaMemcpy is used.
   * @param tname name of the tensor
   * @param data address of the memory buffer to store tensor's data
   * @param size size of the memory buffer
   */
  void SetTensorData(const std::string &tname, void *data, size_t size);

  /**
   * copy the data of a tensor to user specified buffer.
   * if tensor is in NVGPU device memory, cudaMemcpy is used.
   * @param t the tensor
   * @param data address of the memory buffer to store tensor's data
   * @param size size of the memory buffer
   */
  void GetTensorData(hlir::framework::Tensor &t,  // NOLINT
                     void *data,
                     size_t size);
  /**
   * copy the data of a tensor (specified by it's name) to user specified
   * buffer. if tensor is in NVGPU device memory, cudaMemcpy is used.
   * @param tname name of the tensor
   * @param data address of the memory buffer to store tensor's data
   * @param size size of the memory buffer
   */
  void GetTensorData(const std::string &tname, void *data, size_t size);

  /**
   * run the compiled program
   * @param name2podargs optional map overriding the pod-value arguments
   * passed to the program's instructions; nullptr uses the defaults
   */
  void Execute(
      const std::map<std::string, cinn_pod_value_t> *name2podargs = nullptr);

 private:
  // Opaque implementation state (pimpl); owns the compiled program, scope
  // and tensors. Shared so copies of CinnComputation stay cheap.
  std::shared_ptr<ComputationContext> context_;
};

}  // namespace frontend
}  // namespace cinn
